{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "c97a7c94",
   "metadata": {},
   "source": [
    "# Bayer2RGB Usage Demonstration For Other Networks\n",
    "\n",
    "This notebook shows how to use a Bayer2RGB model and its quantized checkpoint to debayer Bayer-pattern images, and compares the accuracy of the \"ai87net-imagenet-effnetv2\" model on images debayered with bilinear interpolation versus images debayered with Bayer2RGB."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a65b4955",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "###################################################################################################\n",
    "#\n",
    "# Copyright © 2023 Analog Devices, Inc. All Rights Reserved.\n",
    "# This software is proprietary and confidential to Analog Devices, Inc. and its licensors.\n",
    "#\n",
    "###################################################################################################\n",
    "import importlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import sys\n",
    "import torch\n",
    "from torch.utils import data\n",
    "import cv2\n",
    "\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "# Make the parent of the working directory, and its 'models' subfolder, importable.\n",
    "sys.path.append(os.path.dirname(os.getcwd()))\n",
    "sys.path.append(os.path.join(os.path.dirname(os.getcwd()), 'models'))\n",
    "\n",
    "# Project modules; presumably resolved through the paths added above - TODO confirm.\n",
    "import ai8x\n",
    "import parse_qat_yaml\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d336ad92",
   "metadata": {},
   "outputs": [],
   "source": [
    "sys.path.append(os.path.join(os.getcwd(), './models/'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "582c27d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The module name contains hyphens, so it cannot be named in a plain `import` statement;\n",
    "# load it dynamically by its string name instead.\n",
    "b2rgb = importlib.import_module(name=\"ai85net-bayer2rgbnet\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "943067bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Args:\n",
    "    def __init__(self, act_mode_8bit):\n",
    "        self.act_mode_8bit = act_mode_8bit\n",
    "        self.truncate_testset = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fc25a42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Evaluation settings ---\n",
    "act_mode_8bit = True  # For evaluation mode, input/output range: -128, 127\n",
    "test_batch_size = 1\n",
    "args = Args(act_mode_8bit=act_mode_8bit)\n",
    "\n",
    "# --- Bayer2RGB checkpoint and the QAT policy file used with it ---\n",
    "# NOTE(review): the policy path below is the ImageNet policy; confirm this is the\n",
    "# policy the Bayer2RGB checkpoint was actually trained with.\n",
    "checkpoint_path_b2rgb = \"../../ai8x-synthesis/trained/ai85-b2rgb-qat8-q.pth.tar\"\n",
    "qat_yaml_file_used_in_training_b2rgb = '../policies/qat_policy_imagenet.yaml'\n",
    "\n",
    "# --- Values later handed to ai8x.set_device ---\n",
    "ai_device, round_avg = 87, True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5336f275",
   "metadata": {},
   "source": [
    "# ImageNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d3ac0fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import imagenet\n",
    "\n",
    "# Classifier module (hyphenated name, hence the dynamic import) and the files it needs.\n",
    "test_model = importlib.import_module(name='models.ai87net-imagenet-effnetv2')\n",
    "data_path = '/data_ssd/'\n",
    "checkpoint_path = \"../../ai8x-synthesis/trained/ai87-imagenet-effnet2-q.pth.tar\"\n",
    "qat_yaml_file_used_in_training = '../policies/qat_policy_imagenet.yaml'\n",
    "\n",
    "# Only the test split is requested from each loader; the first return value is discarded.\n",
    "\n",
    "# Dataset used for Bilinear Interpolation\n",
    "_, test_set_inter = imagenet.imagenet_bayer_fold_2_get_dataset(\n",
    "    (data_path, args),\n",
    "    load_train=False,\n",
    "    load_test=True,\n",
    "    fold_ratio=1,\n",
    ")\n",
    "\n",
    "# Dataset used for Bayer2RGB Debayerization\n",
    "_, test_set = imagenet.imagenet_bayer_fold_2_get_dataset(\n",
    "    (data_path, args),\n",
    "    load_train=False,\n",
    "    load_test=True,\n",
    "    fold_ratio=2,\n",
    ")\n",
    "\n",
    "# Original dataset\n",
    "_, test_set_original = imagenet.imagenet_get_datasets(\n",
    "    (data_path, args),\n",
    "    load_train=False,\n",
    "    load_test=True,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acad3fd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One loader per test set. Shuffling is disabled on all three, so each loader\n",
    "# yields its samples in dataset order.\n",
    "test_dataloader_inter = data.DataLoader(\n",
    "    test_set_inter,\n",
    "    batch_size=test_batch_size,\n",
    "    shuffle=False,\n",
    ")\n",
    "test_dataloader = data.DataLoader(\n",
    "    test_set,\n",
    "    batch_size=test_batch_size,\n",
    "    shuffle=False,\n",
    ")\n",
    "test_dataloader_original = data.DataLoader(\n",
    "    test_set_original,\n",
    "    batch_size=test_batch_size,\n",
    "    shuffle=False,\n",
    ")\n",
    "\n",
    "# Number of batches, then number of samples, of the Bayer2RGB test loader.\n",
    "print(len(test_dataloader), len(test_dataloader.dataset), sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cc02416",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the first CUDA device when one is available, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "# Parse the QAT policy of each network from its YAML file.\n",
    "qat_policy_bayer2rgb = parse_qat_yaml.parse(qat_yaml_file_used_in_training_b2rgb)\n",
    "qat_policy = parse_qat_yaml.parse(qat_yaml_file_used_in_training)\n",
    "\n",
    "# Configure ai8x with the values chosen in the configuration cell.\n",
    "ai8x.set_device(\n",
    "    device=ai_device,\n",
    "    simulate=act_mode_8bit,\n",
    "    round_avg=round_avg,\n",
    ")\n",
    "\n",
    "# Build the Bayer2RGB network and move it to the selected device.\n",
    "model_bayer2rgb = b2rgb.bayer2rgbnet()\n",
    "model_bayer2rgb = model_bayer2rgb.to(device)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa8e4c3a",
   "metadata": {},
   "source": [
    "Run one of the following models according to the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c308ad5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the classifier, then move it to the selected device.\n",
    "# NOTE(review): `bias` receives the string \"--use-bias\" (the command-line flag spelling)\n",
    "# rather than a bool; confirm the model only tests it for truthiness.\n",
    "model = test_model.AI87ImageNetEfficientNetV2(bias=\"--use-bias\")\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1579d9d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show each state-dict entry as name -> shape instead of dumping every weight tensor,\n",
    "# which would flood the cell output for a network of this size.\n",
    "{name: tuple(tensor.shape) for name, tensor in model.state_dict().items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26533322",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show each state-dict entry as name -> shape instead of dumping every weight tensor.\n",
    "{name: tuple(tensor.shape) for name, tensor in model_bayer2rgb.state_dict().items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37758e6b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Prepare both networks and load their quantized checkpoints. The steps below are\n",
    "# applied in this order: fuse BN, switch to QAT, load weights, update the model.\n",
    "\n",
    "# fuse the BN parameters into conv layers before Quantization Aware Training (QAT)\n",
    "ai8x.fuse_bn_layers(model_bayer2rgb)\n",
    "ai8x.fuse_bn_layers(model)\n",
    "\n",
    "# switch model from unquantized to quantized for QAT\n",
    "ai8x.initiate_qat(model_bayer2rgb, qat_policy_bayer2rgb)\n",
    "ai8x.initiate_qat(model, qat_policy)\n",
    "\n",
    "# Read both checkpoint files; map_location places the stored tensors on `device`.\n",
    "checkpoint_b2rgb = torch.load(checkpoint_path_b2rgb,map_location = device)\n",
    "checkpoint = torch.load(checkpoint_path,map_location = device)\n",
    "\n",
    "# strict=False: keys that do not match between checkpoint and model are skipped\n",
    "# instead of raising, so a mismatched checkpoint does not fail here.\n",
    "# NOTE(review): consider inspecting the missing/unexpected keys that load_state_dict returns.\n",
    "model_bayer2rgb.load_state_dict(checkpoint_b2rgb['state_dict'], strict=False)\n",
    "model.load_state_dict(checkpoint['state_dict'], strict=False)\n",
    "\n",
    "# NOTE(review): update_model presumably refreshes quantization-related parameters from\n",
    "# the loaded weights - confirm against ai8x. Each model is moved to `device` afterwards.\n",
    "ai8x.update_model(model_bayer2rgb)\n",
    "model_bayer2rgb = model_bayer2rgb.to(device)\n",
    "ai8x.update_model(model)\n",
    "model = model.to(device)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "79f325e1",
   "metadata": {},
   "source": [
    "# Bayer-to-RGB + AI87ImageNetEfficientNetV2 Model\n",
    "The Bayer2RGB model runs before AI87ImageNetEfficientNetV2 to reconstruct RGB images from Bayer-pattern images; the AI87ImageNetEfficientNetV2 model is then evaluated on the reconstructed images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46d513cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the classifier on images debayered by the Bayer2RGB network.\n",
    "model_bayer2rgb.eval()\n",
    "model.eval()\n",
    "\n",
    "log_interval = 1000  # samples between running-accuracy reports\n",
    "next_log = log_interval\n",
    "correct = 0  # samples classified correctly so far\n",
    "total = 0  # samples evaluated so far\n",
    "with torch.no_grad():\n",
    "    # Images come from the Bayer test set, class labels from the original test set.\n",
    "    # Neither loader shuffles, so zip() pairs samples by dataset position.\n",
    "    for (image1, _), (_, label2) in zip(test_dataloader, test_dataloader_original):\n",
    "        image = image1.to(device)\n",
    "\n",
    "        # Debayer first, then classify the reconstructed image.\n",
    "        primer_out = model_bayer2rgb(image)\n",
    "        model_out = model(primer_out)\n",
    "\n",
    "        # Per-sample argmax over the flattened class scores; valid for any batch size.\n",
    "        result = model_out.flatten(start_dim=1).argmax(dim=1).cpu()\n",
    "\n",
    "        correct += (result == label2).sum().item()\n",
    "        total += label2.size(0)\n",
    "\n",
    "        # Periodic progress report based on the samples evaluated so far.\n",
    "        if total >= next_log:\n",
    "            print(\"accuracy:\")\n",
    "            print(correct / total)\n",
    "            next_log += log_interval\n",
    "\n",
    "# Final accuracy over the evaluated samples (max() guards against an empty loader).\n",
    "print(\"accuracy:\")\n",
    "print(correct / max(total, 1))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "abc140e4",
   "metadata": {},
   "source": [
    "# Model\n",
    "The original dataset is used to evaluate the AI87ImageNetEfficientNetV2 model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "173f9d5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the classifier directly on the original test set.\n",
    "model.eval()\n",
    "\n",
    "log_interval = 1000  # samples between running-accuracy reports\n",
    "next_log = log_interval\n",
    "correct = 0  # samples classified correctly so far\n",
    "total = 0  # samples evaluated so far\n",
    "with torch.no_grad():\n",
    "    for image, label in test_dataloader_original:\n",
    "        image = image.to(device)\n",
    "        model_out = model(image)\n",
    "\n",
    "        # Per-sample argmax over the flattened class scores; valid for any batch size.\n",
    "        result = model_out.flatten(start_dim=1).argmax(dim=1).cpu()\n",
    "\n",
    "        correct += (result == label).sum().item()\n",
    "        total += label.size(0)\n",
    "\n",
    "        # Periodic progress report based on the samples evaluated so far.\n",
    "        if total >= next_log:\n",
    "            print(\"accuracy:\")\n",
    "            print(correct / total)\n",
    "            next_log += log_interval\n",
    "\n",
    "# Final accuracy over the evaluated samples (max() guards against an empty loader).\n",
    "print(\"accuracy:\")\n",
    "print(correct / max(total, 1))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "bea82bf8",
   "metadata": {},
   "source": [
    "# Bilinear Interpolation + Model\n",
    "Bilinear interpolation is applied before AI87ImageNetEfficientNetV2 to reconstruct RGB images from Bayer-pattern images; the model is then evaluated on the reconstructed images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5fe86ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the classifier on images debayered with OpenCV (cv2.COLOR_BayerGR2RGB).\n",
    "model.eval()\n",
    "\n",
    "log_interval = 1000  # samples between running-accuracy reports\n",
    "next_log = log_interval\n",
    "correct = 0  # samples classified correctly so far\n",
    "total = 0  # samples evaluated so far\n",
    "with torch.no_grad():\n",
    "    # Images come from the unfolded Bayer test set, class labels from the original test set.\n",
    "    # Neither loader shuffles, so zip() pairs samples by dataset position.\n",
    "    for (image1, _), (_, label2) in zip(test_dataloader_inter, test_dataloader_original):\n",
    "        rgb_images = []\n",
    "        for bayer in image1:\n",
    "            # Shift the 8-bit activation range [-128, 127] to [0, 255] and go channels-last\n",
    "            # so OpenCV can demosaic the image as uint8.\n",
    "            img = (128 + bayer.cpu().detach().numpy().transpose(1, 2, 0)).astype(np.uint8)\n",
    "            img = cv2.cvtColor(img, cv2.COLOR_BayerGR2RGB)\n",
    "\n",
    "            # Back to channels-first float32 and undo the +128 shift, so the classifier\n",
    "            # receives the same [-128, 127] range as in the other evaluation cells.\n",
    "            chw = np.ascontiguousarray(img.transpose(2, 0, 1), dtype=np.float32) - 128.0\n",
    "            rgb_images.append(torch.from_numpy(chw))\n",
    "\n",
    "        out_tensor = torch.stack(rgb_images).to(device)\n",
    "        model_out = model(out_tensor)\n",
    "\n",
    "        # Per-sample argmax over the flattened class scores; valid for any batch size.\n",
    "        result = model_out.flatten(start_dim=1).argmax(dim=1).cpu()\n",
    "\n",
    "        correct += (result == label2).sum().item()\n",
    "        total += label2.size(0)\n",
    "\n",
    "        # Periodic progress report based on the samples evaluated so far.\n",
    "        if total >= next_log:\n",
    "            print(\"accuracy:\")\n",
    "            print(correct / total)\n",
    "            next_log += log_interval\n",
    "\n",
    "# Final accuracy over the evaluated samples (max() guards against an empty loader).\n",
    "print(\"accuracy:\")\n",
    "print(correct / max(total, 1))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  },
  "vscode": {
   "interpreter": {
    "hash": "570feb405e2e27c949193ac68f46852414290d515b0ba6e5d90d076ed2284471"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
